{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "RDKit WARNING: [14:02:01] Enabling RDKit 2019.09.1 jupyter extensions\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<torch._C.Generator at 0x7f9b8d51b270>"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from torch_geometric.data import DataLoader\n",
    "import torch_geometric\n",
    "import torch.distributions as D\n",
    "import matplotlib.pyplot as plt\n",
    "from rdkit import Chem, DataStructs\n",
    "from rdkit.Chem import AllChem, Draw, rdFMCS, rdMolTransforms\n",
    "from rdkit.Chem.rdMolAlign import AlignMol\n",
    "from rdkit.Chem import PandasTools\n",
    "from rdkit import rdBase\n",
    "import glob\n",
    "import os\n",
    "\n",
    "import deepdock\n",
    "from deepdock.utils.distributions import *\n",
    "from deepdock.utils.data import *\n",
    "from deepdock.models import *\n",
    "\n",
    "%matplotlib inline\n",
    "# Single seed shared by the three seeding calls below, so it is changed in one place.\n",
    "SEED = 123\n",
    "np.random.seed(SEED)\n",
    "torch.cuda.manual_seed_all(SEED)\n",
    "# Kept as the last expression: the cell displays the Generator returned by torch.manual_seed.\n",
    "torch.manual_seed(SEED)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<All keys matched successfully>"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Use the GPU whenever torch reports one; otherwise stay on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n",
    "\n",
    "# Build the LigandNet and TargetNet sub-models and hand them to DeepDock.\n",
    "# NOTE: the keyword is spelled 'dist_threhold' in this call and is kept verbatim.\n",
    "ligand_model = LigandNet(28, residual_layers=10, dropout_rate=0.10)\n",
    "target_model = TargetNet(4, residual_layers=10, dropout_rate=0.10)\n",
    "model = DeepDock(ligand_model, target_model, hidden_dim=64, n_gaussians=10,\n",
    "                 dropout_rate=0.10, dist_threhold=7.)\n",
    "model = model.to(device)\n",
    "\n",
    "# Load the checkpoint shipped next to the deepdock package, remapping its tensors\n",
    "# onto the selected device, then restore the stored model weights.\n",
    "checkpoint_path = deepdock.__path__[0]+'/../Trained_models/DeepDock_pdbbindv2019_13K_minTestLoss.chk'\n",
    "checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))\n",
    "model.load_state_dict(checkpoint['model_state_dict']) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Complexes from pdbBind: 285\n"
     ]
    }
   ],
   "source": [
    "# Archive with the CASF-2016 complexes, addressed relative to the deepdock package folder.\n",
    "casf_archive = deepdock.__path__[0]+'/../data/dataset_CASF-2016_285.tar'\n",
    "db_complex = PDBbind_complex_dataset(\n",
    "    data_path=casf_archive,\n",
    "    min_target_nodes=None,\n",
    "    max_ligand_nodes=None,\n",
    ")\n",
    "n_complexes = len(db_complex)\n",
    "print('Complexes from pdbBind:', n_complexes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "class input_dataset(Dataset):\n",
    "    \"\"\"Dataset that pairs every ligand graph with one shared target mesh.\n",
    "\n",
    "    Each item is the tuple ``(ligand_graph, target_mesh, label)``.\n",
    "\n",
    "    Args:\n",
    "        mols: iterable of molecules; each one is converted once, up front, with\n",
    "            ``from_networkx(mol2graph.mol_to_nx(m))``.\n",
    "        target_mesh: object returned unchanged as the second element of every item.\n",
    "        labels: optional sequence with one entry per molecule. When omitted, the\n",
    "            position of the molecule (0, 1, 2, ...) is used as its label.\n",
    "        transform: forwarded to the base ``Dataset`` constructor.\n",
    "        pre_transform: forwarded to the base ``Dataset`` constructor.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``labels`` is given and its length differs from the number\n",
    "            of molecules.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, mols, target_mesh, labels=None, transform=None, pre_transform=None):\n",
    "        # Forward the transforms instead of accepting and silently dropping them.\n",
    "        # Assumes Dataset is torch_geometric's Dataset (len/get protocol) - TODO confirm.\n",
    "        super(input_dataset, self).__init__(transform=transform, pre_transform=pre_transform)\n",
    "\n",
    "        self.mols = [from_networkx(mol2graph.mol_to_nx(m)) for m in mols]\n",
    "        self.target_mesh = target_mesh\n",
    "\n",
    "        if labels is None:\n",
    "            # Fall back to positional labels.\n",
    "            labels = range(len(self.mols))\n",
    "        elif len(labels) != len(self.mols):\n",
    "            # A mismatch would otherwise surface later as an IndexError or as\n",
    "            # labels attached to the wrong molecules.\n",
    "            raise ValueError('labels has {} entries but {} molecules were given'.format(\n",
    "                len(labels), len(self.mols)))\n",
    "        self.labels = labels\n",
    "\n",
    "    def len(self):\n",
    "        \"\"\"Return the number of ligand graphs held by the dataset.\"\"\"\n",
    "        return len(self.mols)\n",
    "\n",
    "    def get(self, idx):\n",
    "        \"\"\"Return ``(ligand_graph, target_mesh, label)`` for position ``idx``.\"\"\"\n",
    "        return self.mols[idx], self.target_mesh, self.labels[idx]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 6h 54min 8s, sys: 9min 26s, total: 7h 3min 35s\n",
      "Wall time: 4h 25min 33s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "from torch_scatter import scatter_add\n",
    "\n",
    "# Folder that holds one sub-folder of decoy .mol2 files per target id.\n",
    "decoys_root = deepdock.__path__[0]+'/../data/CASF-2016/decoys_screening/'\n",
    "target_files = [os.path.basename(f) for f in glob.glob(decoys_root+'*', recursive=False)]\n",
    "\n",
    "# Cutoffs are applied from largest to smallest because `prob` is zeroed in place:\n",
    "# every row removed by a larger cutoff is also removed by each smaller one.\n",
    "distance_cutoffs = (10, 7, 5, 3)\n",
    "results = []\n",
    "\n",
    "# Inference only: switch to eval mode once and skip autograd bookkeeping,\n",
    "# which otherwise keeps a graph alive for every batch.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    for target_data in db_complex:\n",
    "        ligand, target, _, pdbid = target_data\n",
    "        # Only complexes that have a decoy folder are screened.\n",
    "        if pdbid not in target_files:\n",
    "            continue\n",
    "        decoy_dir = decoys_root+pdbid+'/'\n",
    "        decoy_files = [os.path.basename(f) for f in glob.glob(decoy_dir+'*.mol2', recursive=False)]\n",
    "\n",
    "        for decoy_file in decoy_files:\n",
    "            decoys = Mol2MolSupplier(file=decoy_dir+decoy_file,\n",
    "                                     sanitize=False, cleanupSubstructures=False)\n",
    "            decoy_names = [m.GetProp('Name') for m in decoys]\n",
    "            db_decoys = input_dataset(decoys, target, decoy_names)\n",
    "            loader_decoys = DataLoader(db_decoys, batch_size=20, shuffle=False)\n",
    "\n",
    "            for data in loader_decoys:\n",
    "                decoy, target_temp, cpd_name = data\n",
    "                decoy, target_temp = decoy.to(device), target_temp.to(device)\n",
    "                pi, sigma, mu, dist, atom_types, bond_types, batch = model(decoy, target_temp)\n",
    "\n",
    "                # Per-row mixture density at the model's distance: pi * Normal(mu, sigma)\n",
    "                # evaluated at dist, summed over dim 1 (presumably the mixture components).\n",
    "                normal = Normal(mu, sigma)\n",
    "                logprob = normal.log_prob(dist.expand_as(normal.loc))\n",
    "                logprob += torch.log(pi)\n",
    "                prob = logprob.exp().sum(1)\n",
    "\n",
    "                # `batch` is used as the group index for the per-compound sums;\n",
    "                # presumably one group per decoy in the mini-batch.\n",
    "                n_groups = batch.unique().size(0)\n",
    "                prob_all = scatter_add(prob, batch, dim=0, dim_size=n_groups)\n",
    "\n",
    "                prob_within = {}\n",
    "                for cutoff in distance_cutoffs:\n",
    "                    # Drop the contribution of rows whose distance exceeds the cutoff.\n",
    "                    prob[torch.where(dist > cutoff)[0]] = 0.\n",
    "                    prob_within[cutoff] = scatter_add(prob, batch, dim=0, dim_size=n_groups)\n",
    "\n",
    "                # Column order matches the later table: 3A, 5A, 7A, 10A, all.\n",
    "                prob = torch.stack([prob_within[3], prob_within[5], prob_within[7],\n",
    "                                    prob_within[10], prob_all], dim=1)\n",
    "                results.append(np.concatenate([np.expand_dims(np.repeat(pdbid, len(cpd_name)), axis=1),\n",
    "                                               np.expand_dims(cpd_name, axis=1),\n",
    "                                               prob.cpu().numpy()], axis=1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(1624500, 7)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PDB_ID</th>\n",
       "      <th>Cpd_Name</th>\n",
       "      <th>Score_3A</th>\n",
       "      <th>Score_5A</th>\n",
       "      <th>Score_7A</th>\n",
       "      <th>Score_10A</th>\n",
       "      <th>Score_all</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1o3f</td>\n",
       "      <td>4de2_ligand_1</td>\n",
       "      <td>7.254809622430958</td>\n",
       "      <td>89.94999521484337</td>\n",
       "      <td>447.3108304509815</td>\n",
       "      <td>471.33409258768546</td>\n",
       "      <td>471.33795008409413</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1o3f</td>\n",
       "      <td>4de2_ligand_10</td>\n",
       "      <td>2.2435853524573153</td>\n",
       "      <td>64.8291232744619</td>\n",
       "      <td>469.6149662619709</td>\n",
       "      <td>497.22048202034364</td>\n",
       "      <td>497.22063628573727</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>1o3f</td>\n",
       "      <td>4de2_ligand_100</td>\n",
       "      <td>3.451115585675536</td>\n",
       "      <td>78.15551585586073</td>\n",
       "      <td>506.3139331427199</td>\n",
       "      <td>554.3860309102873</td>\n",
       "      <td>554.3878273892042</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1o3f</td>\n",
       "      <td>4de2_ligand_106</td>\n",
       "      <td>4.784300354428733</td>\n",
       "      <td>102.6058833211317</td>\n",
       "      <td>590.7510395276462</td>\n",
       "      <td>620.5011685339421</td>\n",
       "      <td>620.5046101480829</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>1o3f</td>\n",
       "      <td>4de2_ligand_108</td>\n",
       "      <td>2.349412635601983</td>\n",
       "      <td>76.80465950772964</td>\n",
       "      <td>539.9298935579708</td>\n",
       "      <td>575.3306395604491</td>\n",
       "      <td>575.3335251313035</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  PDB_ID         Cpd_Name            Score_3A           Score_5A  \\\n",
       "0   1o3f    4de2_ligand_1   7.254809622430958  89.94999521484337   \n",
       "1   1o3f   4de2_ligand_10  2.2435853524573153   64.8291232744619   \n",
       "2   1o3f  4de2_ligand_100   3.451115585675536  78.15551585586073   \n",
       "3   1o3f  4de2_ligand_106   4.784300354428733  102.6058833211317   \n",
       "4   1o3f  4de2_ligand_108   2.349412635601983  76.80465950772964   \n",
       "\n",
       "            Score_7A           Score_10A           Score_all  \n",
       "0  447.3108304509815  471.33409258768546  471.33795008409413  \n",
       "1  469.6149662619709  497.22048202034364  497.22063628573727  \n",
       "2  506.3139331427199   554.3860309102873   554.3878273892042  \n",
       "3  590.7510395276462   620.5011685339421   620.5046101480829  \n",
       "4  539.9298935579708   575.3306395604491   575.3335251313035  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Build the table under a new name so `results` stays the per-batch list from the\n",
    "# scoring loop; rebinding it here meant a second run of this cell no longer\n",
    "# received the list it expects.\n",
    "score_columns = ['PDB_ID', 'Cpd_Name', 'Score_3A', 'Score_5A', 'Score_7A', 'Score_10A', 'Score_all']\n",
    "results_df = pd.DataFrame(np.concatenate(results, axis=0), columns=score_columns)\n",
    "results_df.to_csv('Score_decoys_screening_CASF2016.csv', index=False)\n",
    "print(results_df.shape)\n",
    "results_df.head() "
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "import os\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "df = pd.read_csv('Score_decoys_screening_CASF2016.csv')\n",
    "pdbids = df.PDB_ID.unique()\n",
    "print(len(pdbids))\n",
    "\n",
    "# One output tree per score column:\n",
    "#   ScreeningPower_DeepDock_<suffix>/scores/<pdbid>_score.dat\n",
    "score_suffixes = ['3A', '5A', '7A', '10A', 'all']\n",
    "score_dirs = {suffix: 'ScreeningPower_DeepDock_'+suffix+'/scores/' for suffix in score_suffixes}\n",
    "for score_dir in score_dirs.values():\n",
    "    # to_csv does not create missing folders, so make them up front.\n",
    "    os.makedirs(score_dir, exist_ok=True)\n",
    "\n",
    "for pdbid in pdbids:\n",
    "    # Select the rows of this target once and reuse them for every score column.\n",
    "    target_scores = df[df.PDB_ID==pdbid]\n",
    "    for suffix in score_suffixes:\n",
    "        score_column = 'Score_'+suffix\n",
    "        df1 = target_scores[['Cpd_Name', score_column]].rename(\n",
    "            columns={'Cpd_Name': '#code_ligand_num', score_column: 'score'})\n",
    "        df1.to_csv(score_dirs[suffix]+pdbid+'_score.dat', index=False, sep='\\t')"
   ]
  },
  {
   "cell_type": "raw",
   "metadata": {},
   "source": [
    "python /data/CASF-2016/power_screening/forward_screening_power.py -c /data/CASF-2016/power_screening/CoreSet.dat -s scores -t /data/CASF-2016/power_screening/TargetInfo.dat -p 'positive' -o 'Forward_DeepDock' > ForwardScreeningPower_DeepDock_all.out\n",
    "\n",
    "python /data/CASF-2016/power_screening/reverse_screening_power.py -c /data/CASF-2016/power_screening/CoreSet.dat -s scores -l /data/CASF-2016/power_screening/LigandInfo.dat -p 'positive' -o 'Reverse_DeepDock' > ReverseScreeningPower_DeepDock_all.out\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
